/*
 * PASE2.0 - Parallel Augmented Subspace Eigensolver 2.0
 * Copyright (C) 2025--Present by XieGroup and contributors. All rights reserved.
 * Released under the GNU Lesser General Public License 3.0 or later (LGPL-3.0+).
 * See <https://www.gnu.org/licenses/lgpl-3.0.html> for details.
 */
#include "multilevel.h"
#include "eigensolver.h"
#include "errorestimate.h"

#include "pase.h"
#include "app_slepc.h"
#include "petscmat.h"
#include "math.h"

#define BASE_FUNCTION C_T_P2_3D
#define TESTPASE 1 //2 for gcge , else for slepc
static char help[] = "Test with Laplace Mat.\n";

BOUNDARYTYPE BoundCond(INT bdid);
BOUNDARYTYPE MassBoundCond(INT bdid);
void BoundFun(double X[3], int dim, double *values);
void stiffmatrix(DOUBLE *left, DOUBLE *right, DOUBLE *coord, DOUBLE *AuxFEMValues, DOUBLE *value);
void massmatrix(DOUBLE *left, DOUBLE *right, DOUBLE *coord, DOUBLE *AuxFEMValues, DOUBLE *value);
void MatrixRead(Mat **A, Mat **B, Mat **P, int refinetime1, int refinetime2, int levelnum, MPI_Comm comm);

/*
 * Driver: build a 3-level multigrid hierarchy for the Laplace eigenproblem
 * A x = lambda B x and solve it with the solver selected by TESTPASE:
 *   1    -> PASE (multilevel, uses every level and the prolongations)
 *   2    -> direct GCGE on the finest level
 *   else -> direct SLEPc EPS on the finest level
 * The wall-clock time of the selected solve is printed.
 */
int main(int argc, char *argv[])
{
    SlepcInitialize(&argc, &argv, (char *)0, help);
    srand(1);

    int nev = 2000;
    int num_levels = 3;
    /* MatrixRead orders levels coarsest first, so the last index is the
     * finest mesh; the direct solvers below operate on it. */
    int finest = num_levels - 1;
    Mat *A_array, *B_array, *P_array;
    MatrixRead(&A_array, &B_array, &P_array, 4, 1, num_levels, MPI_COMM_WORLD);
    if (TESTPASE == 1)
    { 
        /* NOTE(review): eval/evec are filled by PASE_EigenSolver and never
         * released here - confirm who owns them. */
        double *eval = NULL;
        void **evec = NULL;
        PASE_PARAMETER param;
        PASE_PARAMETER_Create(&param, num_levels, nev, 1e-8, PASE_GMG);
        param->A_array = (void **)A_array;
        param->B_array = (void **)B_array;
        param->P_array = (void **)P_array;
        param->initial_level = 1;
        param->aux_rtol = 1e-9;
        param->pc_type = PRECOND_A;
        param->if_batches = true;
        param->batch_size = 500;
        param->more_aux_nev = 60;
        param->more_batch_size = 70;

        double s_time = MPI_Wtime();
        PASE_EigenSolver(&eval, &evec, param);
        double e_time = MPI_Wtime();
        OpenPFEM_Print("The total time of PASE solving is %f sec\n", e_time - s_time);
        PASE_PARAMETER_Destroy(&param);
    }
    else if (TESTPASE == 2)
    {
        int flag = 0;
        double atol = 1e-2;
        double rtol = 1e-8;
        OpenPFEM_Print("开始GCGE, atol = %e, rtol = %e!\n", atol, rtol);
        double s_time = MPI_Wtime();
        PASE_DIRECT_GCGE((void *)A_array[finest], (void *)B_array[finest], flag, nev, atol, rtol, argc, argv);
        double e_time = MPI_Wtime();
        OpenPFEM_Print("The total time of GCGE solving is %f sec\n", e_time - s_time);
    }
    else
    {
        double atol = 1e-2;
        double rtol = 1e-8;
        OpenPFEM_Print("开始SLEPC, atol = %e, rtol = %e!\n", atol, rtol);
        int flag = 2; /* flag = 2: Krylov-Schur, flag = 6: LOBPCG */
        double s_time = MPI_Wtime();
        PASE_DIRECT_EPS((void *)A_array[finest], (void *)B_array[finest], flag, nev, rtol, argc, argv);
        double e_time = MPI_Wtime();
        OpenPFEM_Print("The total time of Slepc solving is %f sec\n", e_time - s_time);
    }
    SlepcFinalize();
    return 0;
}
void MatrixRead(Mat **A, Mat **B, Mat **P, int refinetime1, int refinetime2, int levelnum, MPI_Comm comm)
{
    MESH *mesh = NULL;
    MeshCreate(&mesh, 3, comm);
    MeshBuild(mesh, "../pase/data/dataCube5.txt", SIMPLEX, TETHEDRAL); 
    MeshUniformRefine(mesh, refinetime1);                     
    MeshPartition(mesh);
    MeshUniformRefine(mesh, refinetime2);

    MULTIINDEX stiffLmultiindex[3] = {D100, D010, D001}, stiffRmultiindex[3] = {D100, D010, D001};
    MULTIINDEX massLmultiindex[1] = {D000}, massRmultiindex[1] = {D000};
    QUADRATURE *Quadrature = QuadratureBuild(QuadTetrahedral56);

    MESH *finermesh = NULL, *coarsemesh = NULL;
    FEMSPACE *finerspace = NULL, *coarsespace = NULL;
    FEMSPACE *massfinerspace = NULL, *masscoarsespace = NULL;
    DISCRETEFORM *StiffDiscreteForm = NULL, *MassDiscreteForm = NULL;
    MATRIX **stiffmatrices = (MATRIX **)malloc(levelnum * sizeof(MATRIX *));
    MATRIX **massmatrices = (MATRIX **)malloc(levelnum * sizeof(MATRIX *));
    MATRIX **prolongs = (MATRIX **)malloc((levelnum - 1) * sizeof(MATRIX *));
    INT i;
    for (i = 0; i < levelnum; i++)
    {
        if (i == 0)
        {
            coarsemesh = mesh;
            coarsespace = FEMSpaceBuild(coarsemesh, BASE_FUNCTION, BoundCond);
            masscoarsespace = FEMSpaceBuild(coarsemesh, BASE_FUNCTION, MassBoundCond);
            StiffDiscreteForm = DiscreteFormBuild(coarsespace, 3, stiffLmultiindex, coarsespace, 3, stiffRmultiindex,
                                                  stiffmatrix, NULL, BoundFun, Quadrature);
            MassDiscreteForm = DiscreteFormBuild(masscoarsespace, 1, massLmultiindex, masscoarsespace, 1, massRmultiindex,
                                                 massmatrix, NULL, BoundFun, Quadrature);
            stiffmatrices[i] = NULL, massmatrices[i] = NULL;
            MatrixAssemble(&(stiffmatrices[i]), NULL, StiffDiscreteForm, TYPE_OPENPFEM);
            MatrixAssemble(&(massmatrices[i]), NULL, MassDiscreteForm, TYPE_OPENPFEM);
            MatrixDeleteDirichletBoundary(stiffmatrices[i], StiffDiscreteForm);
            MatrixDeleteDirichletBoundary(massmatrices[i], MassDiscreteForm);
            MatrixConvert(stiffmatrices[i], TYPE_PETSC, 0);
            MatrixConvert(massmatrices[i], TYPE_PETSC, 0);
            DiscreteFormDestroy(&StiffDiscreteForm);
            DiscreteFormDestroy(&MassDiscreteForm);
        }
        else
        {
            finermesh = MeshDuplicate(coarsemesh);
            MeshUniformRefine(finermesh, 1);
            finerspace = FEMSpaceBuild(finermesh, BASE_FUNCTION, BoundCond);
            massfinerspace = FEMSpaceBuild(finermesh, BASE_FUNCTION, MassBoundCond);
            StiffDiscreteForm = DiscreteFormBuild(finerspace, 3, stiffLmultiindex, finerspace, 3, stiffRmultiindex,
                                                  stiffmatrix, NULL, BoundFun, Quadrature);
            MassDiscreteForm = DiscreteFormBuild(massfinerspace, 1, massLmultiindex, massfinerspace, 1, massRmultiindex,
                                                 massmatrix, NULL, BoundFun, Quadrature);
            stiffmatrices[i] = NULL, massmatrices[i] = NULL, prolongs[i - 1] = NULL;
            MatrixAssemble(&(stiffmatrices[i]), NULL, StiffDiscreteForm, TYPE_OPENPFEM);
            MatrixAssemble(&(massmatrices[i]), NULL, MassDiscreteForm, TYPE_OPENPFEM);
            MatrixDeleteDirichletBoundary(stiffmatrices[i], StiffDiscreteForm);
            MatrixDeleteDirichletBoundary(massmatrices[i], MassDiscreteForm);
            MatrixConvert(stiffmatrices[i], TYPE_PETSC, 0);
            MatrixConvert(massmatrices[i], TYPE_PETSC, 0);
            ProlongMatrixAssemble(&(prolongs[i - 1]), coarsespace, finerspace, TYPE_OPENPFEM);
            ProlongDeleteDirichletBoundary(prolongs[i - 1], coarsespace, finerspace);
            MatrixConvert(prolongs[i - 1], TYPE_PETSC, 0);
            MeshDestroy(&coarsemesh);
            FEMSpaceDestroy(&coarsespace);
            DiscreteFormDestroy(&StiffDiscreteForm);
            DiscreteFormDestroy(&MassDiscreteForm);
            coarsemesh = finermesh;
            coarsespace = finerspace;
        }
    }

    MeshDestroy(&coarsemesh);
    FEMSpaceDestroy(&coarsespace);
    FEMSpaceDestroy(&masscoarsespace);
    QuadratureDestroy(&Quadrature);

    Mat *A_array = (Mat *)malloc(levelnum * sizeof(Mat));
    Mat *B_array = (Mat *)malloc(levelnum * sizeof(Mat));
    int size, rank;
    MPI_Comm_size(MPI_COMM_WORLD, &size);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    int *localsize = (int *)malloc(levelnum * sizeof(int));
    int *maxsize = (int *)malloc(levelnum * sizeof(int));
    int *minsize = (int *)malloc(levelnum * sizeof(int));
    for (i = 0; i < levelnum; i++)
    {
        A_array[i] = (Mat)(stiffmatrices[i]->data_petsc);
        B_array[i] = (Mat)(massmatrices[i]->data_petsc);
        localsize[i] = stiffmatrices[i]->local_nrows;
    }
    MPI_Reduce(localsize, maxsize, levelnum, MPI_INT, MPI_MAX, 0, MPI_COMM_WORLD);
    MPI_Reduce(localsize, minsize, levelnum, MPI_INT, MPI_MIN, 0, MPI_COMM_WORLD);
    free(localsize);
    free(maxsize);
    free(minsize);
    Mat *P_array = (Mat *)malloc((levelnum - 1) * sizeof(Mat));
    for (i = 0; i < levelnum - 1; i++)
    {
        P_array[i] = (Mat)(prolongs[i]->data_petsc);
    }
    *A = A_array;
    *B = B_array;
    *P = P_array;
}

/*
 * Boundary classification for the stiffness space: any positive boundary id
 * is Dirichlet, everything else (id <= 0) is an interior entity.
 */
BOUNDARYTYPE BoundCond(INT bdid)
{
    return (bdid > 0) ? DIRICHLET : INNER;
}

/*
 * Boundary classification for the mass space: ids <= 0 are interior,
 * positive ids are tagged MASSDIRICHLET.
 */
BOUNDARYTYPE MassBoundCond(INT bdid)
{
    if (bdid <= 0)
    {
        return INNER;
    }
    return MASSDIRICHLET;
}

/*
 * Boundary data callback: homogeneous (zero) boundary values.
 * The coordinate X and the dimension are ignored; only values[0] is written.
 */
void BoundFun(double X[3], int dim, double *values)
{
    (void)X;
    (void)dim;
    *values = 0.0;
}

/*
 * Stiffness integrand: dot product of the three first-derivative components
 * of the test (left) and trial (right) functions. coord and AuxFEMValues are
 * unused. The sum is accumulated in index order 0, 1, 2.
 */
void stiffmatrix(DOUBLE *left, DOUBLE *right, DOUBLE *coord, DOUBLE *AuxFEMValues, DOUBLE *value)
{
    DOUBLE dot = left[0] * right[0];
    for (int k = 1; k < 3; k++)
    {
        dot += left[k] * right[k];
    }
    value[0] = dot;
}

/*
 * Mass integrand: product of the test and trial function values.
 * coord and AuxFEMValues are unused.
 */
void massmatrix(DOUBLE *left, DOUBLE *right, DOUBLE *coord, DOUBLE *AuxFEMValues, DOUBLE *value)
{
    const DOUBLE u = *left;
    const DOUBLE v = *right;
    *value = u * v;
}
